{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import time\n",
    "from torch.optim import lr_scheduler\n",
    "\n",
    "\n",
    "\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import random_split\n",
    "\n",
    "from torchvision import transforms, models\n",
    "from MyEDFImports import load_all_labels, stages_names_3_outputs, three_stages_transform\n",
    "from tempfile import TemporaryDirectory"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:16:34.803250755Z",
     "start_time": "2023-06-16T16:16:34.757219933Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "data_dir = 'images_(19248, 224, 224)_wav_morlet_sqpy.npy'\n",
    "data_np = np.load(data_dir)\n",
    "targets = load_all_labels()\n",
    "targets = three_stages_transform(targets)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:18:22.233125552Z",
     "start_time": "2023-06-16T16:16:34.806264355Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "class DatasetFromNp(Dataset):\n",
    "    def __init__(self, data, target, transform=None):\n",
    "        self.data = data\n",
    "        self.target = target\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        x = self.data[index]\n",
    "        # Extending 1channel image to be put to three channels\n",
    "        x = torch.from_numpy(x)\n",
    "        x = torch.unsqueeze(x, 0)\n",
    "        x = x.expand(3, -1, -1)\n",
    "        y = self.target[index]\n",
    "        if self.transform:\n",
    "            x = self.transform(x)\n",
    "        return x, y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.target)\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:18:22.812762459Z",
     "start_time": "2023-06-16T16:18:22.274262484Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "# setting up tranforms for the images (just normalizing pretty much)\n",
    "data_transforms = transforms.Compose([\n",
    "    # to tensor can probably be just torch.from_numpy\n",
    "    # transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "dataset_all = DatasetFromNp(data_np, target=targets, transform=data_transforms)\n",
    "\n",
    "generator = torch.Generator()  # .manual_seed(22)\n",
    "train_data, test_data = random_split(dataset_all, [0.8, 0.2], generator=generator)\n",
    "datasets = {'train': train_data, 'val': test_data}\n",
    "dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in\n",
    "               ['train', 'val']}\n",
    "dataset_sizes = {x: len(datasets[x]) for x in ['train', 'val']}\n",
    "# remove a fixed generator for training\n",
    "\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:18:22.923693197Z",
     "start_time": "2023-06-16T16:18:22.274517615Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [],
   "source": [
    "# this imshow should be fixed doesn't work yet\n",
    "def imshow(inp, title=None):\n",
    "    # inp = inp.numpy().transpose((1, 2, 0))\n",
    "    # mean = np.array([0.485, 0.456, 0.406])\n",
    "    # std = np.array([0.229, 0.224, 0.225])\n",
    "    # inp = std * inp + mean\n",
    "    # inp = np.clip(inp, 0, 1)\n",
    "    plt.imshow(inp)\n",
    "    if title is not None:\n",
    "        plt.title(title)\n",
    "    plt.pause(0.001)\n",
    "\n",
    "\n",
    "# not necessarily test_data subset has to be there can be anything else\n",
    "inputs, classes = next(iter(dataloaders['train']))\n",
    "grid = torchvision.utils.make_grid(inputs)\n",
    "# imshow(grid, title= classes)\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:18:23.296612182Z",
     "start_time": "2023-06-16T16:18:22.925947004Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [],
   "source": [
    "def train_model(model, criterion, optimizer, scheduler, num_epochs=25):\n",
    "    since = time.time()\n",
    "\n",
    "    # Create a temporary directory to save training checkpoints\n",
    "    with TemporaryDirectory() as tempdir:\n",
    "        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')\n",
    "\n",
    "        torch.save(model.state_dict(), best_model_params_path)\n",
    "        best_acc = 0.0\n",
    "\n",
    "        for epoch in range(num_epochs):\n",
    "            print(f'Epoch {epoch + 1}/{num_epochs}')\n",
    "            print('-' * 10)\n",
    "\n",
    "            # Each epoch has a training and validation phase\n",
    "            for phase in ['train', 'val']:\n",
    "                if phase == 'train':\n",
    "                    model.train()  # Set model to training mode\n",
    "                else:\n",
    "                    model.eval()  # Set model to evaluate mode\n",
    "\n",
    "                running_loss = 0.0\n",
    "                running_corrects = 0\n",
    "\n",
    "                # Iterate over data.\n",
    "                for inputs, labels in dataloaders[phase]:\n",
    "                    # No idea why now for inputs why I need to transfer it to a float from a double\n",
    "                    inputs = inputs.to(device, dtype=torch.float)\n",
    "                    labels = labels.to(device)\n",
    "\n",
    "                    # zero the parameter gradients\n",
    "                    optimizer.zero_grad()\n",
    "\n",
    "                    # forward\n",
    "                    # track history if only in train\n",
    "                    with torch.set_grad_enabled(phase == 'train'):\n",
    "                        outputs = model(inputs)\n",
    "                        _, preds = torch.max(outputs, 1)\n",
    "                        loss = criterion(outputs, labels)\n",
    "\n",
    "                        # backward + optimize only if in training phase\n",
    "                        if phase == 'train':\n",
    "                            loss.backward()\n",
    "                            optimizer.step()\n",
    "\n",
    "                    # statistics\n",
    "                    running_loss += loss.item() * inputs.size(0)\n",
    "                    running_corrects += torch.sum(preds == labels.data)\n",
    "                if phase == 'train':\n",
    "                    scheduler.step()\n",
    "\n",
    "                epoch_loss = running_loss / dataset_sizes[phase]\n",
    "                epoch_acc = running_corrects.double() / dataset_sizes[phase]\n",
    "\n",
    "                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')\n",
    "\n",
    "                # deep copy the model\n",
    "                if phase == 'val' and epoch_acc > best_acc:\n",
    "                    best_acc = epoch_acc\n",
    "                    torch.save(model.state_dict(), best_model_params_path)\n",
    "\n",
    "            print()\n",
    "\n",
    "        time_elapsed = time.time() - since\n",
    "        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')\n",
    "        print(f'Best val Acc: {best_acc:4f}')\n",
    "\n",
    "        # load best model weights\n",
    "        model.load_state_dict(torch.load(best_model_params_path))\n",
    "    return model"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:18:23.305787718Z",
     "start_time": "2023-06-16T16:18:23.302122348Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "data": {
      "text/plain": "512"
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)\n",
    "num_ftrs = model_18.fc.in_features\n",
    "num_ftrs"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:18:23.538684919Z",
     "start_time": "2023-06-16T16:18:23.306131719Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [],
   "source": [
    "model_18.fc = nn.Linear(num_ftrs, 3)\n",
    "model_18.to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer_18 = optim.SGD(model_18.parameters(), lr=0.001, momentum=0.9)\n",
    "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_18, step_size=7, gamma=0.1)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:18:24.891685105Z",
     "start_time": "2023-06-16T16:18:23.538261247Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/25\n",
      "----------\n",
      "train Loss: 0.9668 Acc: 0.6209\n",
      "val Loss: 0.9101 Acc: 0.6345\n",
      "\n",
      "Epoch 2/25\n",
      "----------\n",
      "train Loss: 0.9334 Acc: 0.6362\n",
      "val Loss: 1.1156 Acc: 0.6345\n",
      "\n",
      "Epoch 3/25\n",
      "----------\n",
      "train Loss: 0.8736 Acc: 0.6588\n",
      "val Loss: 1.1564 Acc: 0.6345\n",
      "\n",
      "Epoch 4/25\n",
      "----------\n",
      "train Loss: 0.8175 Acc: 0.6785\n",
      "val Loss: 1.0303 Acc: 0.6345\n",
      "\n",
      "Epoch 5/25\n",
      "----------\n",
      "train Loss: 0.7721 Acc: 0.6936\n",
      "val Loss: 3.3806 Acc: 0.6345\n",
      "\n",
      "Epoch 6/25\n",
      "----------\n",
      "train Loss: 0.7399 Acc: 0.7041\n",
      "val Loss: 3.7381 Acc: 0.6345\n",
      "\n",
      "Epoch 7/25\n",
      "----------\n",
      "train Loss: 0.6984 Acc: 0.7243\n",
      "val Loss: 0.9348 Acc: 0.6345\n",
      "\n",
      "Epoch 8/25\n",
      "----------\n",
      "train Loss: 0.5883 Acc: 0.7689\n",
      "val Loss: 1.2038 Acc: 0.1572\n",
      "\n",
      "Epoch 9/25\n",
      "----------\n",
      "train Loss: 0.5564 Acc: 0.7852\n",
      "val Loss: 1.2773 Acc: 0.1839\n",
      "\n",
      "Epoch 10/25\n",
      "----------\n",
      "train Loss: 0.5353 Acc: 0.7943\n",
      "val Loss: 1.7858 Acc: 0.1816\n",
      "\n",
      "Epoch 11/25\n",
      "----------\n",
      "train Loss: 0.5185 Acc: 0.7995\n",
      "val Loss: 1.0324 Acc: 0.6376\n",
      "\n",
      "Epoch 12/25\n",
      "----------\n",
      "train Loss: 0.5015 Acc: 0.8064\n",
      "val Loss: 25.1963 Acc: 0.1839\n",
      "\n",
      "Epoch 13/25\n",
      "----------\n",
      "train Loss: 0.4797 Acc: 0.8207\n",
      "val Loss: 36.3123 Acc: 0.1839\n",
      "\n",
      "Epoch 14/25\n",
      "----------\n",
      "train Loss: 0.4602 Acc: 0.8218\n",
      "val Loss: 1.0254 Acc: 0.6368\n",
      "\n",
      "Epoch 15/25\n",
      "----------\n",
      "train Loss: 0.4077 Acc: 0.8463\n",
      "val Loss: 42.2782 Acc: 0.1839\n",
      "\n",
      "Epoch 16/25\n",
      "----------\n",
      "train Loss: 0.3957 Acc: 0.8514\n",
      "val Loss: 2.8872 Acc: 0.2307\n",
      "\n",
      "Epoch 17/25\n",
      "----------\n",
      "train Loss: 0.3931 Acc: 0.8531\n",
      "val Loss: 3.3171 Acc: 0.1816\n",
      "\n",
      "Epoch 18/25\n",
      "----------\n",
      "train Loss: 0.3856 Acc: 0.8568\n",
      "val Loss: 66.5642 Acc: 0.1839\n",
      "\n",
      "Epoch 19/25\n",
      "----------\n",
      "train Loss: 0.3791 Acc: 0.8590\n",
      "val Loss: 3.5807 Acc: 0.1816\n",
      "\n",
      "Epoch 20/25\n",
      "----------\n",
      "train Loss: 0.3736 Acc: 0.8615\n",
      "val Loss: 55.0418 Acc: 0.1839\n",
      "\n",
      "Epoch 21/25\n",
      "----------\n",
      "train Loss: 0.3703 Acc: 0.8645\n",
      "val Loss: 61.7360 Acc: 0.1839\n",
      "\n",
      "Epoch 22/25\n",
      "----------\n",
      "train Loss: 0.3587 Acc: 0.8693\n",
      "val Loss: 1.8015 Acc: 0.4422\n",
      "\n",
      "Epoch 23/25\n",
      "----------\n",
      "train Loss: 0.3604 Acc: 0.8670\n",
      "val Loss: 1.1058 Acc: 0.6402\n",
      "\n",
      "Epoch 24/25\n",
      "----------\n",
      "train Loss: 0.3602 Acc: 0.8684\n",
      "val Loss: 23.0243 Acc: 0.1878\n",
      "\n",
      "Epoch 25/25\n",
      "----------\n",
      "train Loss: 0.3577 Acc: 0.8675\n",
      "val Loss: 1.8257 Acc: 0.4362\n",
      "\n",
      "Training complete in 27m 51s\n",
      "Best val Acc: 0.640166\n"
     ]
    }
   ],
   "source": [
    "model_152 = train_model(model_18, criterion, optimizer_18, exp_lr_scheduler,\n",
    "                        num_epochs=25)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:46:15.554434275Z",
     "start_time": "2023-06-16T16:18:24.892866089Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [],
   "source": [
    "dir_saved_models='saved_models'\n",
    "torch.save(model_18, dir_saved_models+f'/ResNet18_16_06')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:46:15.619419881Z",
     "start_time": "2023-06-16T16:46:15.554217214Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [],
   "source": [
    "model_loaded = torch.load('saved_models/ResNet152_16_06')\n",
    "model_loaded.to(device)\n",
    "model_loaded.eval()\n",
    "dl_for_all = torch.utils.data.DataLoader(dataset_all, batch_size=1, num_workers=4)\n",
    "predictions = []\n",
    "with torch.no_grad():\n",
    "    for i,l in dl_for_all:\n",
    "        i = i.to(device, dtype=torch.float32)\n",
    "        l.to(device)\n",
    "        all_outputs = model_loaded(i)\n",
    "        predictions.append(all_outputs)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:51:04.382123096Z",
     "start_time": "2023-06-16T16:46:15.623411011Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "model_loaded.cpu()\n",
    "del model_loaded\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:51:04.568376443Z",
     "start_time": "2023-06-16T16:51:04.422209544Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [
    {
     "data": {
      "text/plain": "19248"
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(predictions)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:51:04.578487868Z",
     "start_time": "2023-06-16T16:51:04.570221128Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [],
   "source": [
    "prrrred = [ torch.max(prediction, 1)[1] for prediction in predictions]\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:51:04.780517585Z",
     "start_time": "2023-06-16T16:51:04.576213243Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "outputs": [
    {
     "data": {
      "text/plain": "False"
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prrrred == targets"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:51:04.790318619Z",
     "start_time": "2023-06-16T16:51:04.782980151Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "outputs": [
    {
     "data": {
      "text/plain": "0.9502805486284289"
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "good = 0\n",
    "for i in range(len(targets)):\n",
    "    if targets[i] == prrrred[i]:\n",
    "        good +=1\n",
    "good / len(targets)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:51:05.158579316Z",
     "start_time": "2023-06-16T16:51:04.830230818Z"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
